
Record: CROWN-Q + Full GPTQ + SWA/EMA Blend — val_bpb 1.1186 (3-seed mean)#693

Open
EthanYangTW wants to merge 5 commits into openai:main from EthanYangTW:submission/crownq-record-v2

Conversation

@EthanYangTW

Summary

  • CROWN-Q: Novel curvature-weighted quantization variance penalty during warmdown. Encourages weights into flat minima where int6 quantization causes less damage.
  • Full Cholesky GPTQ: Hessian-aware quantization with act-order. GPTQ runs after the 585s training phase as part of model export.
  • SWA/EMA 50/50 blend: Stochastic Weight Averaging blended with EMA (0.997).
  • Architecture: 11L, 512d, GQA 8/4, MLP 3x LeakyReLU(0.5)^2, XSA last 4 layers (7-10), VRL, BigramHash 3072.
  • Eval: Sliding window stride=64, pure inference. No TTT (TTT_ENABLED=0).
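As a minimal sketch of the SWA/EMA blend in the summary above (function names and state-dict handling are illustrative assumptions; only the 50/50 blend weight and the 0.997 EMA decay are stated in the PR):

```python
import torch

@torch.no_grad()
def ema_update(ema: dict, model_state: dict, decay: float = 0.997) -> None:
    """In-place EMA update with the 0.997 decay quoted in the summary."""
    for k, v in model_state.items():
        ema[k].mul_(decay).add_(v, alpha=1.0 - decay)

@torch.no_grad()
def blend_swa_ema(swa: dict, ema: dict, alpha: float = 0.5) -> dict:
    """50/50 convex blend of SWA and EMA weights (alpha=0.5).
    Assumes both state dicts share keys and tensor shapes."""
    return {k: alpha * swa[k] + (1.0 - alpha) * ema[k] for k in swa}
```

The blended state dict would then be loaded into the model before export/quantization.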

Results

| Seed | Steps | Sliding BPB | Artifact (bytes) |
| --- | --- | --- | --- |
| 1337 | 6613 | 1.1189 | 15,945,134 |
| 42 | 6612 | 1.1189 | 15,947,742 |
| 7 | 6613 | 1.1179 | 15,938,790 |
| **Mean** | | **1.1186** | |
| **Std** | | **0.0006** | |
  • Training: 585s wallclock, 87ms/step (FA3 Hopper)
  • All artifacts < 16,000,000 bytes
  • Sliding window eval time: ~75s

What is CROWN-Q?

Training-time penalty per weight row: lambda * mean(w^2) * delta^2 / 12 where delta = row_max / 15. The CROWN-Q step size (row_max/15) is intentionally larger than the actual quantizer step size (row_max/31, clip_range=31) — this over-penalization pushes weights further into flat basins, providing extra robustness margin against quantization damage. Applied only during warmdown when QAT is active. Zero eval-time cost.
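The per-row penalty above can be sketched as follows (function name and the final sum reduction are assumptions; the formula `lambda * mean(w^2) * delta^2 / 12` with `delta = row_max / 15` is taken from the description):

```python
import torch

def crownq_penalty(weight: torch.Tensor, lam: float) -> torch.Tensor:
    """CROWN-Q row-wise penalty: lam * mean(w^2) * delta^2 / 12 per row,
    with delta = row_max / 15 -- intentionally coarser than the actual
    int6 quantizer step (row_max / 31) to buy extra flat-minimum margin."""
    row_max = weight.abs().amax(dim=1)       # (rows,)
    delta = row_max / 15.0                   # over-sized quantization step
    mean_sq = weight.pow(2).mean(dim=1)      # (rows,)
    # delta^2 / 12 is the variance of uniform quantization noise
    return (lam * mean_sq * delta.pow(2) / 12.0).sum()
```

This term would be added to the training loss only during the warmdown phase, when QAT is active.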

Why No TTT?

AdamW TTT destroys GPTQ-quantized weights (+0.077 BPB degradation). Full-weight AdamW at lr=0.002 on quantized models causes the carefully optimized GPTQ weight placement to diverge. SGD TTT is neutral-to-harmful. TTT_ENABLED is set to 0 in the submitted code.

Compliance

  • Training: 585s wallclock (under 600s)
  • GPTQ calibration uses training data only
  • Eval is pure inference (sliding window), no TTT
  • All artifacts <= 16,000,000 bytes
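For illustration, a sketch of stride-64 sliding-window evaluation as pure inference (window length, model interface, and the bits-per-token normalization are assumptions; only stride=64 is stated above, and a real bpb figure would additionally normalize by bytes per token):

```python
import math
import torch

@torch.no_grad()
def sliding_window_bits_per_token(model, tokens: torch.Tensor,
                                  window: int = 1024, stride: int = 64) -> float:
    """Score each token with (near-)full left context by re-running the model
    on overlapping windows and counting only the last `stride` positions."""
    total_nll, total_tokens = 0.0, 0
    for start in range(0, tokens.numel() - window, stride):
        chunk = tokens[start:start + window + 1]
        logits = model(chunk[:-1].unsqueeze(0))          # (1, window, vocab)
        logp = torch.log_softmax(logits.float(), dim=-1)
        nll = -logp.gather(-1, chunk[1:].unsqueeze(0).unsqueeze(-1)).squeeze(-1)
        total_nll += nll[0, -stride:].sum().item()       # only the new tokens
        total_tokens += stride
    return total_nll / total_tokens / math.log(2)        # nats -> bits per token
```

The overlap is what makes this eval slow relative to a single chunked pass (~75s here), since each token's window is recomputed every `stride` steps.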

…=1.1162)

int5 GPTQ quantization with Hessian-aware error compensation enables 33.6M
params in 16MB. Soft-Round QAT (differentiable tanh rounding, alpha 1→16)
replaces STE for better training quality at zero cost.

3-seed results:
- Seed 1337: val_bpb=1.1155, artifact=15.82MB
- Seed 42:   val_bpb=1.1163, artifact=15.42MB
- Seed 7:    val_bpb=1.1167, artifact=15.37MB
- Mean: 1.1162 (std 0.0006)
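The "differentiable tanh rounding" mentioned above could look like the following sketch (the exact normalization is an assumption; only the tanh form and the alpha 1→16 anneal come from the commit message):

```python
import math
import torch

def soft_round(w: torch.Tensor, alpha: float) -> torch.Tensor:
    """Soft-Round sketch: a smooth tanh step between integer grid points.
    As alpha grows (annealed 1 -> 16 during training) this approaches a
    hard round() while keeping nonzero gradients, unlike an STE."""
    floor_w = torch.floor(w)
    frac = w - floor_w
    # Normalized so frac=0 maps to 0, frac=1 maps to 1, frac=0.5 to 0.5.
    step = 0.5 * (1.0 + torch.tanh(alpha * (frac - 0.5)) / math.tanh(alpha / 2.0))
    return floor_w + step
```

At high alpha the output is numerically indistinguishable from hard rounding, so the train/export mismatch vanishes by the end of the anneal.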
Copilot AI review requested due to automatic review settings March 25, 2026 07:26

Copilot AI left a comment


Pull request overview

Adds a new record submission folder under records/track_10min_16mb/ capturing an experiment that combines CROWN-Q warmdown regularization, full (Cholesky) GPTQ export-time quantization, and a 50/50 SWA+EMA weight blend, with sliding-window evaluation.

Changes:

  • Added a self-contained train_gpt.py implementing CROWN-Q, SWA/EMA blending, GPTQ calibration/quantization, and sliding-window eval (plus optional TTT routines).
  • Added record metadata (submission.json) and documentation (README.md) describing the method and results.
  • Added three seed logs intended to substantiate the reported mean/stdev.

Reviewed changes

Copilot reviewed 3 out of 6 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_gpt.py | Core training/export/eval script for the submission (CROWN-Q, GPTQ, SWA/EMA, sliding-window eval). |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/README.md | Submission write-up, config notes, and results summary. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/submission.json | Leaderboard/record metadata (metrics, seeds, artifact sizes). |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed1337.log | Training/eval log for seed 1337 supporting reported numbers. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed42.log | Training/eval log for seed 42 supporting reported numbers. |
| records/track_10min_16mb/2026-03-25_CROWNQ_GPTQ_SlidingWindow/train_seed7.log | Training/eval log for seed 7 supporting reported numbers. |


```python
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP
from flash_attn_interface import flash_attn_func as flash_attn_3_func
```

Copilot AI Mar 25, 2026


flash_attn_interface is imported unconditionally, so the script will crash at startup in environments where FlashAttention 3 isn't installed (requirements.txt doesn’t list it, and other record scripts typically have a safe fallback). Consider wrapping this import in a try/except and falling back to PyTorch SDPA (or another available attention backend) when the import fails, so the record is reproducible across the standard evaluation environment.

Suggested change (replacing the unconditional import):

```python
try:
    # Optional FlashAttention 3 backend; used when available.
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    HAS_FLASH_ATTN_3 = True
except Exception:
    # Fall back to PyTorch scaled_dot_product_attention when FlashAttention 3 is not installed.
    HAS_FLASH_ATTN_3 = False

    def flash_attn_3_func(q: Tensor, k: Tensor, v: Tensor, *args, **kwargs) -> Tensor:
        """
        Compatibility wrapper that mimics flash_attn_func using PyTorch SDPA.
        Accepts extra *args/**kwargs for flexibility and ignores unsupported options.
        """
        # Best-effort mapping for positional arguments commonly used with flash_attn_func:
        # flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, ...)
        dropout_p = 0.0
        causal = False
        if len(args) >= 1:
            dropout_p = float(args[0])
        if len(args) >= 3:
            causal = bool(args[2])
        # Keyword arguments override positional ones if provided.
        if "dropout_p" in kwargs:
            dropout_p = float(kwargs["dropout_p"])
        if "causal" in kwargs:
            causal = bool(kwargs["causal"])
        attn_mask = kwargs.get("attn_mask", None)
        # PyTorch scaled_dot_product_attention expects (B, H, S, D); flash_attn uses (B, S, H, D).
        q_t, k_t, v_t = q, k, v
        transposed = False
        if q.dim() == 4 and q.size(1) != q.size(2):
            q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
            transposed = True
        out = F.scaled_dot_product_attention(
            q_t, k_t, v_t, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=causal,
        )
        # Restore the original (B, S, H, D) layout if we transposed.
        if transposed:
            out = out.transpose(1, 2)
        return out
```

Comment on lines +6 to +11
```json
  "date": "2026-03-25T06:30:00Z",
  "val_loss": 1.8886,
  "val_loss_std": 0.0009,
  "val_bpb": 1.1186,
  "val_bpb_std": 0.0006,
  "seeds": [1337, 42, 7],
```

Copilot AI Mar 25, 2026


submission.json deviates from the schema used by the other /records/track_10min_16mb/*/submission.json examples (e.g., missing fields like pre_quant_val_loss, pre_quant_val_bpb, step_stop, wallclock_seconds, eval_time_seconds, and a bytes_model_* breakdown). If any tooling expects the established keys, this new format may break ingestion; consider aligning to the existing schema and adding the additional fields while keeping the per-seed breakdown as extra metadata.

Comment on lines +1673 to +1687
```python
# === TTT BURST: Late-stage sharpening on recent training data ===
if args.ttt_burst_enabled and hasattr(train_loader, '_ttt_buffer') and len(train_loader._ttt_buffer) > 0:
    ttt_buffer = train_loader._ttt_buffer
    log0(f"ttt_burst:start epochs:{args.ttt_burst_epochs} buffer_size:{len(ttt_buffer)} lr_factor:{args.ttt_burst_lr_factor}")
    ttt_lr_scale = args.ttt_burst_lr_factor
    for ttt_epoch in range(args.ttt_burst_epochs):
        ttt_epoch_loss = 0.0
        for ttt_i, (bx, by) in enumerate(ttt_buffer):
            zero_grad_all()
            for opt in optimizers:
                for group in opt.param_groups:
                    group["lr"] = group["base_lr"] * ttt_lr_scale
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                ttt_loss = model(bx, by)
            (ttt_loss * grad_scale).backward()
```

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ttt_burst performs additional gradient steps on training data after the main loop stops due to MAX_WALLCLOCK_SECONDS. If the track’s “10 minute training” constraint is interpreted as total training compute/wallclock until final weights are produced, this effectively adds extra training beyond the cap. Consider either (a) accounting for ttt_burst time/steps inside the wallclock cap logic or (b) disabling TTT_BURST_ENABLED by default for the record run and documenting it explicitly.

Comment on lines +1738 to +1763
```python
# GPTQ: collect Hessians for calibration-based quantization
hessians = None
if args.gptq_enabled:
    log0(f"gptq:collecting hessians batches={args.gptq_calib_batches}")
    t_hess = time.perf_counter()
    calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    hessians = collect_hessians(
        base_model, calib_loader, args, device, grad_accum_steps,
        num_batches=args.gptq_calib_batches,
    )
    log0(f"gptq:hessians collected layers={len(hessians)} time={time.perf_counter() - t_hess:.1f}s")
    del calib_loader
    torch.cuda.empty_cache()
quant_result, quant_meta = mixed_quantize_int6(
    sd_cpu, {"mlp", "attn"}, hessians=hessians, gptq_block_size=args.gptq_block_size,
)
# Selective +/-1 pruning: zero out least-impactful quantized values to fit target size
target_bytes = 16_000_000
code_bytes = len(code.encode("utf-8"))
target_model_bytes = target_bytes - code_bytes - 50_000  # headroom

def _serialize_and_compress(qr, qm):
    buf = io.BytesIO()
    torch.save({"w": qr, "m": qm}, buf)
    return lzma.compress(buf.getvalue(), preset=6)

test_blob = _serialize_and_compress(quant_result, quant_meta)
log0(f"gptq:pre_prune artifact={len(test_blob)} target={target_model_bytes}")
```

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPTQ calibration (Hessian collection) and the subsequent quantization/pruning are currently executed on every rank when running distributed. Since only rank 0 writes the artifacts, this multiplies the export-time compute and memory pressure by world_size for no benefit. Consider gating the Hessian collection + mixed_quantize_int6 + pruning to master_process and having other ranks wait at a barrier and then load the produced artifact for evaluation.

Suggested change (gating the export path to rank 0):

```python
if master_process:
    # GPTQ: collect Hessians for calibration-based quantization
    hessians = None
    if args.gptq_enabled:
        log0(f"gptq:collecting hessians batches={args.gptq_calib_batches}")
        t_hess = time.perf_counter()
        calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
        hessians = collect_hessians(
            base_model, calib_loader, args, device, grad_accum_steps,
            num_batches=args.gptq_calib_batches,
        )
        log0(
            f"gptq:hessians collected layers={len(hessians)} "
            f"time={time.perf_counter() - t_hess:.1f}s"
        )
        del calib_loader
        torch.cuda.empty_cache()
    quant_result, quant_meta = mixed_quantize_int6(
        sd_cpu, {"mlp", "attn"}, hessians=hessians, gptq_block_size=args.gptq_block_size,
    )
    # Selective +/-1 pruning: zero out least-impactful quantized values to fit target size
    target_bytes = 16_000_000
    code_bytes = len(code.encode("utf-8"))
    target_model_bytes = target_bytes - code_bytes - 50_000  # headroom

    def _serialize_and_compress(qr, qm):
        buf = io.BytesIO()
        torch.save({"w": qr, "m": qm}, buf)
        return lzma.compress(buf.getvalue(), preset=6)

    test_blob = _serialize_and_compress(quant_result, quant_meta)
    log0(f"gptq:pre_prune artifact={len(test_blob)} target={target_model_bytes}")
```

Comment on lines +8 to +10
```markdown
- **Architecture**: 11L, 512d, GQA 8H/4KV, MLP 3x LeakyReLU(0.5)^2, XSA on last 4 layers (7-10), VRL, BigramHash 3072, partial RoPE 16/64.
- **Eval**: Sliding window with stride=64. No test-time training.
```


Copilot AI Mar 25, 2026


README states “No test-time training” (eval is pure sliding-window inference), but the included logs show ttt:start / ttt_sliding:start being run. Please reconcile this by regenerating logs with TTT disabled, or clarifying in the README that the TTT section in the logs was a separate diagnostic run and not part of the reported score.
